package it.uniroma2.exp.task;

import it.uniroma2.tk.TreeKernel;
import it.uniroma2.util.tree.ArtificialTreeGenerator;
import it.uniroma2.util.tree.Tree;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

import edu.berkeley.compbio.jlibsvm.ImmutableSvmParameterPoint.Builder;
import edu.berkeley.compbio.jlibsvm.binary.BinaryClassificationProblemImpl;
import edu.berkeley.compbio.jlibsvm.binary.BinaryClassificationSVM;
import edu.berkeley.compbio.jlibsvm.binary.BinaryModel;
import edu.berkeley.compbio.jlibsvm.binary.C_SVC;

public class RTETaskTester {

	/** Number of distinct trees generated for each of the training and testing sets. */
	private static final int SET_SIZE = 100;

	/**
	 * Maximum generator calls allowed per requested example before giving up. It turns a
	 * potential infinite loop (a generator that keeps producing duplicates) into an explicit error.
	 */
	private static final int MAX_ATTEMPTS_PER_EXAMPLE = 1000;

	/** Source of random trees. The constructor argument is 0 (presumably a seed — TODO confirm). */
	private final ArtificialTreeGenerator atg = new ArtificialTreeGenerator(0);

	/**
	 * Entry point: creates a tester and runs a single train/test experiment.
	 *
	 * @param args ignored
	 * @throws Exception if tree generation or SVM training fails
	 */
	public static void main(String[] args) throws Exception {
		RTETaskTester att = new RTETaskTester();
		att.run();
	}

	/**
	 * Generates a training set and a disjoint testing set of random trees, and trains a C-SVC
	 * with a {@link TreeKernel} on the training set. It prints the label balance of both sets,
	 * the training time and the accuracy (as a percentage) on the testing set.
	 *
	 * @throws Exception if tree generation or SVM training fails
	 */
	// NOTE(review): presumably needed for the jlibsvm generic calls below — TODO confirm and narrow.
	@SuppressWarnings("unchecked")
	public void run() throws Exception {
		System.out.println("Generating training and testing set...");
		HashMap<Tree, Boolean> trainExamples = generateExamples(SET_SIZE);
		// Give each training tree a sequential id; the mapping is handed to the jlibsvm problem below.
		HashMap<Tree, Integer> exampleIds = new HashMap<Tree, Integer>(trainExamples.size());
		int id = 0;
		for (Tree t : trainExamples.keySet()) {
			exampleIds.put(t, id);
			id++;
		}
		System.out.println("Positive examples in training set: " + percentPositive(trainExamples) + "%");
		// Exclude training trees so the measured accuracy is not inflated by examples seen in training.
		HashMap<Tree, Boolean> testExamples = generateExamples(SET_SIZE, trainExamples.keySet());
		System.out.println("Positive examples in testing set: " + percentPositive(testExamples) + "%");

		System.out.println("Building parameters...");
		BinaryClassificationSVM<Boolean, Tree> svm = new C_SVC<Boolean, Tree>();
		BinaryClassificationProblemImpl<Boolean, Tree> artificialProblem =
				new BinaryClassificationProblemImpl<Boolean, Tree>(Boolean.class, trainExamples, exampleIds);
		Builder<Boolean, Tree> svmParamBuilder = new Builder<Boolean, Tree>();
		svmParamBuilder.kernel = new TreeKernel();
		svmParamBuilder.eps = 0.001f;
		svmParamBuilder.cache_size = 128;

		System.out.println("Training...");
		long time = System.currentTimeMillis();
		BinaryModel<Boolean, Tree> model = svm.train(artificialProblem, svmParamBuilder.build());
		System.out.println((System.currentTimeMillis() - time) / 1000.0 + " seconds");

		System.out.println("Testing...");
		int correct = 0;
		for (Map.Entry<Tree, Boolean> example : testExamples.entrySet()) {
			if (model.predictLabel(example.getKey()).equals(example.getValue())) {
				correct++;
			}
		}
		// Scale to a percentage so the value matches the "%" suffix (previously printed a fraction).
		System.out.println("Accuracy: " + (100.0 * correct / testExamples.size()) + "%");
	}

	/**
	 * Returns the percentage (0-100) of entries in {@code examples} whose label is {@code true}.
	 * Multiplying before dividing keeps integral percentages exact (e.g. 57.0, not 56.99999999999999).
	 */
	private static double percentPositive(Map<Tree, Boolean> examples) {
		int positive = 0;
		for (Boolean label : examples.values()) {
			if (Boolean.TRUE.equals(label)) {
				positive++;
			}
		}
		return 100.0 * positive / examples.size();
	}

	/**
	 * Generates {@code number} distinct random trees labelled by {@link #isPositiveExample(Tree)}.
	 *
	 * @param number how many distinct trees to generate
	 * @return a map from each generated tree to its label
	 * @throws Exception if the tree generator fails
	 * @throws IllegalStateException if enough distinct trees cannot be produced
	 */
	private HashMap<Tree, Boolean> generateExamples(int number) throws Exception {
		return generateExamples(number, Collections.<Tree>emptySet());
	}

	/**
	 * Generates {@code number} distinct random trees that are not in {@code excluded}, each
	 * labelled by {@link #isPositiveExample(Tree)}.
	 *
	 * @param number how many distinct trees to generate
	 * @param excluded trees that must not appear in the result (e.g. the training set)
	 * @return a map from each generated tree to its label
	 * @throws Exception if the tree generator fails
	 * @throws IllegalStateException if enough distinct, non-excluded trees cannot be produced
	 *         within the attempt budget
	 */
	private HashMap<Tree, Boolean> generateExamples(int number, Set<Tree> excluded) throws Exception {
		HashMap<Tree, Boolean> examples = new HashMap<Tree, Boolean>();
		long maxAttempts = (long) number * MAX_ATTEMPTS_PER_EXAMPLE;
		long attempts = 0;
		// Duplicates collapse in the map, so keep drawing until enough distinct trees exist.
		while (examples.size() < number) {
			if (attempts >= maxAttempts) {
				throw new IllegalStateException("Could not generate " + number
						+ " distinct examples after " + maxAttempts + " attempts");
			}
			attempts++;
			Tree t = atg.generateRandomTree();
			if (!excluded.contains(t)) {
				examples.put(t, isPositiveExample(t));
			}
		}
		return examples;
	}

	/**
	 * Returns whether the tree rooted at {@code t} contains the production C -> A B anywhere,
	 * checking {@code t} itself and then its descendants recursively.
	 */
	private boolean isPositiveExample(Tree t) {
		if (isProductionCtoAB(t)) {
			return true;
		}
		for (Tree c : t.getChildren()) {
			if (isPositiveExample(c)) {
				return true;
			}
		}
		return false;
	}

	/** Returns whether {@code t} is labelled "C" and has exactly two children labelled "A" and "B", in order. */
	private static boolean isProductionCtoAB(Tree t) {
		// Constant-first equals avoids an NPE on a null label.
		return "C".equals(t.getRootLabel())
				&& t.getChildren().size() == 2
				&& "A".equals(t.getChildren().get(0).getRootLabel())
				&& "B".equals(t.getChildren().get(1).getRootLabel());
	}

}
